[MDC-01] 必要なモジュールをインポートして、乱数のシードを設定します。
In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
np.random.seed(20160703)
tf.set_random_seed(20160703)
[MDC-02] MNISTのデータセットを用意します。
In [2]:
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
[MDC-03] フィルターに対応する Variable を用意して、入力データにフィルターとプーリング層を適用する計算式を定義します。
In [3]:
num_filters = 16
x = tf.placeholder(tf.float32, [None, 784])
x_image = tf.reshape(x, [-1,28,28,1])
W_conv = tf.Variable(tf.truncated_normal([5,5,1,num_filters],
stddev=0.1))
h_conv = tf.nn.conv2d(x_image, W_conv,
strides=[1,1,1,1], padding='SAME')
h_pool =tf.nn.max_pool(h_conv, ksize=[1,2,2,1],
strides=[1,2,2,1], padding='SAME')
[MDC-04] プーリング層からの出力を全結合層を経由してソフトマックス関数に入力する計算式を定義します。
In [4]:
h_pool_flat = tf.reshape(h_pool, [-1, 14*14*num_filters])
num_units1 = 14*14*num_filters
num_units2 = 1024
w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))
b2 = tf.Variable(tf.zeros([num_units2]))
hidden2 = tf.nn.relu(tf.matmul(h_pool_flat, w2) + b2)
w0 = tf.Variable(tf.zeros([num_units2, 10]))
b0 = tf.Variable(tf.zeros([10]))
p = tf.nn.softmax(tf.matmul(hidden2, w0) + b0)
[MDC-05] 誤差関数 loss、トレーニングアルゴリズム train_step、正解率 accuracy を定義します。
In [5]:
t = tf.placeholder(tf.float32, [None, 10])
loss = -tf.reduce_sum(t * tf.log(p))
train_step = tf.train.AdamOptimizer(0.0005).minimize(loss)
correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
[MDC-06] セッションを用意して、Variable を初期化します。
In [6]:
sess = tf.Session()
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
[MDC-07] パラメーターの最適化を4000回繰り返します。
最終的に、テストセットに対して約98%の正解率が得られます。
In [7]:
i = 0
for _ in range(4000):
i += 1
batch_xs, batch_ts = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, t: batch_ts})
if i % 100 == 0:
loss_val, acc_val = sess.run([loss, accuracy],
feed_dict={x:mnist.test.images, t:mnist.test.labels})
print ('Step: %d, Loss: %f, Accuracy: %f'
% (i, loss_val, acc_val))
saver.save(sess, 'mdc_session', global_step=i)
[MDC-08] セッション情報を保存したファイルが生成されていることを確認します。
In [8]:
!ls mdc_session*